/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include <vector>

#include <sparta/DisjointUnionAbstractDomain.h>

#include "Show.h"
#include "SignedConstantDomain.h"
#include "TemplateUtil.h"
#include "TypeUtil.h"

class DexField;
class DexMethod;

// Abstract value of an attribute whose value is a string (a `const DexString*`
// constant lattice).
using StringDomain = sparta::ConstantAbstractDomain<const DexString*>;

// Abstract value of an attribute whose value is a class object (a
// `const DexType*` constant lattice).
using ConstantClassObjectDomain =
    sparta::ConstantAbstractDomain<const DexType*>;

// The abstract value of one immutable attribute: exactly one of a signed
// numeric domain, a string domain, or a class-object domain.
using AttrDomain = sparta::DisjointUnionAbstractDomain<
    SignedConstantDomain,
    StringDomain,
    ConstantClassObjectDomain>;

// Three-valued answer to a runtime-equality question: definitely equal,
// definitely different, or undetermined.
enum TriState {
  True,
  False,
  Unknown,
};

/**
 * Binary visitor over two AttrDomain alternatives that answers whether the
 * two abstract values are equal at runtime.
 *
 * Only values of the same alternative, both holding a known constant, can be
 * decided; every other combination (different alternatives, top, bottom)
 * yields TriState::Unknown.
 */
class tristate_runtime_equals_visitor : public boost::static_visitor<TriState> {
 public:
  // Signed constants: decided only when both sides carry a single constant.
  TriState operator()(const SignedConstantDomain& scd_left,
                      const SignedConstantDomain& scd_right) const {
    const auto lhs = scd_left.get_constant();
    const auto rhs = scd_right.get_constant();
    if (lhs && rhs) {
      return (*lhs == *rhs) ? TriState::True : TriState::False;
    }
    return TriState::Unknown;
  }

  // String and class-object domains: equal iff both hold a value and the
  // underlying pointers (the constants) are identical.
  template <typename Constant,
            typename = std::enable_if_t<
                template_util::contains<Constant,
                                        const DexString*,
                                        const DexType*>::value>>
  TriState operator()(
      const sparta::ConstantAbstractDomain<Constant>& d1,
      const sparta::ConstantAbstractDomain<Constant>& d2) const {
    if (d1.is_value() && d2.is_value()) {
      return (*d1.get_constant() == *d2.get_constant()) ? TriState::True
                                                        : TriState::False;
    }
    return TriState::Unknown;
  }

  // Fallback for mismatched alternatives: nothing can be concluded.
  template <typename Domain, typename OtherDomain>
  TriState operator()(const Domain&, const OtherDomain&) const {
    return TriState::Unknown;
  }
};

/**
 * Object with immutable primitive attributes.
 *
 * For instance, enum objects may have other non-final instance fields, but they
 * always have a constant ordinal and name. Boxed integers are
 * constant. Another example is the type-tag field generated by
 * Class Merging.
 *
 * #clang-format off
 * an_enum_object {
 *    `Ljava/lang/Enum;.ordinal:()I` returns an int constant.
 *    `Ljava/lang/Enum;.name:()Ljava/lang/String;` returns a string constant.
 * }
 *
 * a_boxed_integer_object {
 *    `Ljava/lang/Integer;.intValue:()I` returns an int constant.
 * }
 *
 * a_class_merging_shape_object {
 *  final int type_tag;  // is an immutable field.
 * }
 * #clang-format on
 */
struct ImmutableAttr {
  struct Attr {
    enum Kind { Method, Field } kind;
    union Val {
      explicit Val(DexField* f) : field(f) {}
      explicit Val(DexMethod* m) : method(m) {}
      DexField* field;
      DexMethod* method;
    } val;
    // Only used by test cases.
    Attr() : kind{Method}, val((DexMethod*)nullptr) {}
    explicit Attr(DexField* f) : kind(Field), val(f) {
      always_assert(!f->is_def() || (!is_static(f) && is_final(f)));
    }
    explicit Attr(DexMethod* m) : kind(Method), val(m) {
      if (m->is_def()) {
        always_assert(!is_static(m) && !is_constructor(m));
      }
    }

    bool is_method() const { return kind == Method; }
    bool is_field() const { return kind == Field; }

    // Accessing the non-active union member is undefined behavior.
    uintptr_t as_uintptr_t() const {
      uintptr_t tmp;
      static_assert(sizeof(uintptr_t) == sizeof(Val));
      memcpy(&tmp, &val, sizeof(uintptr_t));
      return tmp;
    }

    bool operator==(const Attr& other) const {
      return kind == other.kind && as_uintptr_t() == other.as_uintptr_t();
    }
    bool operator!=(const Attr& other) const { return !operator==(other); }
    // Compare pointer address for simplicity. The comparison is used
    // for keeping the constructed attributes of an object in order and easy to
    // be compared.
    bool operator<(const Attr& other) const {
      return as_uintptr_t() < other.as_uintptr_t();
    }
  } attr;
  AttrDomain value;

  ImmutableAttr(const Attr& attr, const AttrDomain& value)
      : attr(attr), value(value) {}
  ImmutableAttr(const Attr& attr, const SignedConstantDomain& value)
      : attr(attr), value(value) {}

  ImmutableAttr(const Attr& attr, const StringDomain& value)
      : attr(attr), value(value) {}

  ImmutableAttr(const Attr& attr, const ConstantClassObjectDomain& value)
      : attr(attr), value(value) {}

  TriState runtime_equals(const ImmutableAttr& other) const {
    return AttrDomain::apply_visitor(tristate_runtime_equals_visitor(), value,
                                     other.value);
  }

  bool value_is_constant() const {
    return AttrDomain::apply_visitor(tristate_runtime_equals_visitor(),
                                     value,
                                     value) == TriState::True;
  }

  bool operator==(const ImmutableAttr& other) const {
    return attr == other.attr && value == other.value;
  }

  bool same_key(const ImmutableAttr& other) const { return attr == other.attr; }

  friend std::ostream& operator<<(std::ostream& out, const ImmutableAttr& x) {
    if (x.attr.is_field()) {
      out << "f:" << x.attr.val.field->str();
    } else {
      out << "m:" << x.attr.val.method->str();
    }
    out << "=" << x.value;
    return out;
  }
};

struct ObjectWithImmutAttr {
  // This is true only for cached boxed objects. When it's true, it means the
  // attributes contain all the instance fields of the type and their runtime
  // equality can be determined.
  bool jvm_cached_singleton;
  const DexType* type;
  // The attributes contains part of the instance fields of the type which
  // should be already sorted by Attr.
  std::vector<ImmutableAttr> attributes;

  ObjectWithImmutAttr(const DexType* type, uint32_t size)
      : jvm_cached_singleton(false), type(type) {
    attributes.reserve(size);
  }

  /**
   * Check Java object reference equality.
   * Example:
   * 1. False: Integer{1} and Integer{2} is not equal.
   * 2. Unknown: Two boxed Integer object with intValue be 1, they are equal iff
   * they are cached, otherwise their equality is not determined.
   * 3. True: Cached boxed Integer{1} is equal to another cached boxed
   * Integer{1}.
   */
  TriState runtime_equals(const ObjectWithImmutAttr& other) const {
    if (type != other.type) {
      if (jvm_cached_singleton && other.jvm_cached_singleton) {
        return TriState::False;
      }
      // Can do more type checking on the two types but might not worthy.
      return TriState::Unknown;
    }
    size_t i = 0, j = 0;
    bool all_equal = true;
    for (; i < attributes.size() && j < other.attributes.size();) {
      const auto& attr1 = attributes[i];
      const auto& attr2 = other.attributes[j];
      if (attr1.attr == attr2.attr) {
        auto attr_equality = attr1.runtime_equals(attr2);
        if (attr_equality == TriState::False) {
          return TriState::False;
        } else if (attr_equality == TriState::Unknown) {
          all_equal = false;
        }
        i++;
        j++;
      } else if (attr1.attr < attr2.attr) {
        all_equal = false;
        i++;
      } else {
        all_equal = false;
        j++;
      }
    }
    if (i < attributes.size() || j < other.attributes.size()) {
      return TriState::Unknown;
    }
    return all_equal && jvm_cached_singleton && other.jvm_cached_singleton
               ? TriState::True
               : TriState::Unknown;
  }

  TriState same_type(const ObjectWithImmutAttr& other) const {
    if (type == other.type) {
      return TriState::True;
    }
    // Can do more type checking on the two types but might not worthy.
    return TriState::Unknown;
  }

  bool same_attrs(const ObjectWithImmutAttr& other) const {
    if (same_type(other) != TriState::True ||
        attributes.size() != other.attributes.size()) {
      return false;
    }
    for (size_t idx = 0; idx < attributes.size(); idx++) {
      auto& attr1 = attributes[idx];
      const auto& attr2 = other.attributes[idx];
      if (attr1.attr != attr2.attr) {
        return false;
      }
    }
    return true;
  }

  bool leq(const ObjectWithImmutAttr& other) const {
    redex_assert(type == other.type);
    if (jvm_cached_singleton && !other.jvm_cached_singleton) {
      return false;
    }
    if (attributes.size() > other.attributes.size()) {
      return false;
    }
    for (size_t idx = 0; idx < attributes.size(); idx++) {
      auto& attr1 = attributes[idx];
      const auto& attr2 = other.attributes[idx];
      if (attr1.attr != attr2.attr) {
        return false;
      }
      if (!attr1.value.leq(attr2.value)) {
        return false;
      }
    }
    return true;
  }

  void join_with(const ObjectWithImmutAttr& other) {
    redex_assert(type == other.type);
    bool is_all_constant = true;
    for (size_t idx = 0; idx < attributes.size(); idx++) {
      auto& attr1 = attributes[idx];
      const auto& attr2 = other.attributes[idx];
      attr1.value.join_with(attr2.value);
      is_all_constant = is_all_constant && attr1.value_is_constant();
    }
    jvm_cached_singleton =
        jvm_cached_singleton && other.jvm_cached_singleton && is_all_constant;
  }

  void meet_with(const ObjectWithImmutAttr& other) {
    redex_assert(type == other.type);
    for (size_t idx = 0; idx < attributes.size(); idx++) {
      auto& attr1 = attributes[idx];
      const auto& attr2 = other.attributes[idx];
      attr1.value.meet_with(attr2.value);
    }
    jvm_cached_singleton &= other.jvm_cached_singleton;
  }

  bool operator==(const ObjectWithImmutAttr& other) const {
    if (jvm_cached_singleton != other.jvm_cached_singleton ||
        type != other.type || attributes.size() != other.attributes.size()) {
      return false;
    }
    for (size_t idx = 0; idx < attributes.size(); idx++) {
      const auto& attr1 = attributes[idx];
      const auto& attr2 = other.attributes[idx];
      if (attr1.attr != attr2.attr || attr1.value != attr2.value) {
        return false;
      }
    }
    return true;
  }

  template <typename ValueType>
  void write_value(const ImmutableAttr::Attr& attr, ValueType value) {
#ifndef NDEBUG
    // Insertions are supposed to be in order. Thus a comparison check against
    // the last element is enough.
    always_assert_log(attributes.empty() || attributes.back().attr < attr,
                      "%s is written before, is it real final attribute?",
                      [&]() {
                        auto& att = attributes.back();
                        if (att.attr.is_method()) {
                          return show(att.attr.val.method);
                        } else {
                          return show(att.attr.val.field);
                        }
                      }()
                          .c_str());
#endif
    attributes.push_back(ImmutableAttr(attr, value));
  }

  bool empty() const { return attributes.empty(); }

  boost::optional<const AttrDomain> get_value(const DexMethod* method) const {
    for (const auto& attr : attributes) {
      if (attr.attr.is_method() && attr.attr.val.method == method) {
        return attr.value;
      }
    }
    return boost::none;
  }

  boost::optional<const AttrDomain> get_value(const DexField* field) const {
    for (const auto& attr : attributes) {
      if (attr.attr.is_field() && attr.attr.val.field == field) {
        return attr.value;
      }
    }
    return boost::none;
  }

  friend std::ostream& operator<<(std::ostream& out,
                                  const ObjectWithImmutAttr& x) {
    out << (x.jvm_cached_singleton ? "[c]" : "")
        << type::get_simple_name(x.type) << "{";
    for (auto& attr : x.attributes) {
      out << attr << ",";
    }
    out << "}";
    return out;
  }
};

/**
 * This domain stores an object with **immutable** attributes. The
 * attributes must be immutable, for example, final primitive instance fields
 * that are never changed after initialization, regardless of whether the object
 * may escape.
 * +----------------+-----------------------------------+
 * |                | Boxed primitive objects           |
 * |                +----------------------+            |
 * | Normal Objects | Cached Boxed objects |            |
 * | T1{x}          +~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ +~ ~ ~ ~ ~ ~ +
 * | T2{y}          | Integer{1}           | Integer{1} |
 * |                +~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ +~ ~ ~ ~ ~ ~ +
 * |                | Integer{0}           | Integer{0} |
 * |                +----------------------+~ ~ ~ ~ ~ ~ +
 * |                |                                   |
 * +----------------+----------------------+------------+
 */
class ObjectWithImmutAttrDomain final
    : public sparta::AbstractDomain<ObjectWithImmutAttrDomain> {
 public:
  // Default-constructs to top.
  ObjectWithImmutAttrDomain() { this->set_to_top(); }

  explicit ObjectWithImmutAttrDomain(ObjectWithImmutAttr&& obj)
      : m_kind(sparta::AbstractValueKind::Value),
        m_value(std::make_unique<ObjectWithImmutAttr>(std::move(obj))) {}

  // Deep copy: each domain instance uniquely owns its object. `m_value` is
  // non-null only when `other` is a Value.
  ObjectWithImmutAttrDomain(const ObjectWithImmutAttrDomain& other)
      : m_kind(other.m_kind) {
    if (other.m_value) {
      m_value = std::make_unique<ObjectWithImmutAttr>(*other.m_value);
    }
  }

  ObjectWithImmutAttrDomain(ObjectWithImmutAttrDomain&& other) = default;

  // Deep copy assignment. `m_value` is reset when `other` holds no object
  // (top/bottom) so that no stale object survives; get_constant() relies on
  // `m_value` being null for top/bottom.
  ObjectWithImmutAttrDomain& operator=(const ObjectWithImmutAttrDomain& other) {
    if (this == &other) {
      return *this;
    }
    m_kind = other.m_kind;
    if (other.m_value) {
      m_value = std::make_unique<ObjectWithImmutAttr>(*other.m_value);
    } else {
      m_value.reset();
    }
    return *this;
  }

  ObjectWithImmutAttrDomain& operator=(ObjectWithImmutAttrDomain&& other) =
      default;

  // A copy of the tracked object, or none when top/bottom.
  boost::optional<ObjectWithImmutAttr> get_constant() const {
    return m_value ? boost::make_optional(*m_value) : boost::none;
  }

  bool is_bottom() const { return m_kind == sparta::AbstractValueKind::Bottom; }

  bool is_value() const { return m_kind == sparta::AbstractValueKind::Value; }

  bool is_top() const { return m_kind == sparta::AbstractValueKind::Top; }

  void set_to_bottom() {
    m_kind = sparta::AbstractValueKind::Bottom;
    m_value = nullptr;
  }

  void set_to_top() {
    m_kind = sparta::AbstractValueKind::Top;
    m_value = nullptr;
  }

  bool leq(const ObjectWithImmutAttrDomain& other) const {
    if (is_bottom()) {
      return true;
    }
    // A non-bottom element is never below bottom. This check must come before
    // the Value/Value case, which dereferences other.m_value (null here).
    if (other.is_bottom()) {
      return false;
    }
    if (is_top()) {
      return other.is_top();
    }
    if (other.is_top()) {
      return true;
    }
    return m_value->type == other.m_value->type && m_value->leq(*other.m_value);
  }

  bool equals(const ObjectWithImmutAttrDomain& other) const {
    if (is_bottom()) {
      return other.is_bottom();
    }
    if (is_top()) {
      return other.is_top();
    }
    if (other.m_kind != sparta::AbstractValueKind::Value) {
      return false;
    }
    return *m_value == *other.m_value;
  }

  // Objects with different types or attribute keys join to top.
  void join_with(const ObjectWithImmutAttrDomain& other) {
    if (is_top() || other.is_bottom()) {
      return;
    }
    if (other.is_top()) {
      set_to_top();
      return;
    }
    if (is_bottom()) {
      m_kind = other.m_kind;
      m_value = std::make_unique<ObjectWithImmutAttr>(*other.m_value);
      return;
    }
    if (m_value->same_attrs(*other.m_value)) {
      m_value->join_with(*other.m_value);
    } else {
      set_to_top();
    }
  }

  void widen_with(const ObjectWithImmutAttrDomain& other) { join_with(other); }

  // Objects that are provably different at runtime meet to bottom. So does a
  // pointwise meet that makes any attribute value bottom. Incomparable
  // objects (different types/keys) fall back to top.
  void meet_with(const ObjectWithImmutAttrDomain& other) {
    if (is_bottom() || other.is_top()) {
      return;
    }
    if (other.is_bottom()) {
      set_to_bottom();
      return;
    }
    if (is_top()) {
      m_kind = other.m_kind;
      m_value = std::make_unique<ObjectWithImmutAttr>(*other.m_value);
      return;
    }
    auto equality = m_value->runtime_equals(*other.m_value);
    if (equality == TriState::True) {
      return;
    } else if (equality == TriState::False) {
      set_to_bottom();
    } else {
      always_assert(equality == TriState::Unknown);
      if (m_value->same_attrs(*other.m_value)) {
        m_value->meet_with(*other.m_value);
        for (auto& attr : m_value->attributes) {
          if (attr.value.is_bottom()) {
            set_to_bottom();
            return;
          }
        }
        return;
      }
      set_to_top();
    }
  }

  void narrow_with(const ObjectWithImmutAttrDomain& other) { meet_with(other); }

  friend std::ostream& operator<<(std::ostream& out,
                                  const ObjectWithImmutAttrDomain& x) {
    using namespace sparta;
    switch (x.m_kind) {
    case sparta::AbstractValueKind::Bottom: {
      out << "_|_";
      break;
    }
    case sparta::AbstractValueKind::Top: {
      out << "T";
      break;
    }
    case sparta::AbstractValueKind::Value: {
      out << *x.m_value;
      break;
    }
    }
    return out;
  }

 private:
  sparta::AbstractValueKind m_kind;
  // Non-null iff m_kind == Value.
  std::unique_ptr<ObjectWithImmutAttr> m_value;
};
